# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #  
#
#' @title  Induce labelings from crowd-sourced elite-criticism codings of the
#'          tweets sampled in the first round of annotation
#' @author Hauke Licht
#
# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #

# setup ----
# fix the random-number generator state so that jittered plots (and any other
# random draws below) are reproducible across runs
set.seed(seed = 1234)

# load packages
library(readr)
library(dplyr)
library(tidyr)
library(purrr)

library(ggplot2)
library(ggridges)

library(AnnotationModelsR)

# Directory layout, relative to the working directory
# (NOTE(review): assumes the script is run from the project root — confirm).
base_path <- "."
data_path <- file.path(base_path, "data")
codings_path <- file.path(data_path, "intermediate", "codings")
fits_path <- file.path(data_path, "fits")

# project helper functions (presumably including `entropy()`, used below)
source(file.path(base_path, "code", "helpers", "utils.R"))

# load codings ----

# crowd-sourced judgments; columns used below include worker_id, item_id,
# coding, duration, start_time and end_time
# (NOTE(review): presumably one row per coder-tweet judgment — confirm)
codings <- file.path(codings_path, "codings_sample_1.rds") %>%
  read_rds()

# inspect descriptives ----

# Display labels for the coding categories, keyed by the raw `coding` values.
# The values carry embedded double quotes, presumably so that they parse as
# plotmath string literals (e.g. with ggplot2's `label_parsed` labeller).
label_abbr <- c(
  "yes-general" = '"General"'
  , "yes-specific" = '"Specific"'
  , "yes-unsure" = '"Ambiguous"'
  , "no" = '"No"'
  , "cannot-answer" = '"Cannot answer"'
)


# number of judgments per label class (missing codings, if any, form their
# own row)
codings %>% count(coding)

# compute and inspect coder stats ----

# coder stats
# Per-coder summary statistics. `badbadnotgood` holds the coder's worker_id
# when the coder is flagged as a likely spammer and NA otherwise.
coder_stats <- codings %>%
  group_by(worker_id) %>%
  summarize(
    # how many tweets the coder judged, and how diverse their labels were
    n_judgments = n_distinct(item_id),
    judgment_entropy = entropy(coding),
    judgment_set = list(rev(sort(table(coding)))),
    # distribution of the coder's per-judgment durations
    mean_duration = mean(duration),
    median_duration = median(duration),
    q10_duration = quantile(duration, 0.10),
    q25_duration = quantile(duration, 0.25),
    q75_duration = quantile(duration, 0.75),
    q90_duration = quantile(duration, 0.90),
    # flag coders with very few judgments, or whose label use is too
    # concentrated given how many judgments they contributed; the first
    # branch already covers n_judgments <= 5, so the second needs no n guard
    badbadnotgood = case_when(
      n_judgments <= 5 ~ unique(worker_id),
      judgment_entropy < 0.25 ~ unique(worker_id),
      n_judgments > 50 & judgment_entropy < 0.5 ~ unique(worker_id),
      n_judgments > 100 & judgment_entropy < 1 ~ unique(worker_id),
      TRUE ~ NA_character_
    )
  ) %>%
  ungroup()

# Histogram of the number of judgments each coder contributed. The bin count
# scales with the number of coders (J) but is floored at 1, because
# geom_histogram() rejects `bins = 0` (which floor(J / 4) yields when J < 4).
J <- n_distinct(codings$worker_id)
ggplot(count(codings, worker_id), aes(x = n)) +
  geom_histogram(bins = max(1, floor(J / 4)), fill = "black") +
  scale_x_continuous(trans = "log10") +
  labs(
    title = "Distribution of judgments contributed per coder"
    , x = "Number of judgments contributed (log-10 scale)"
    , y = "Number of coders"
  )

# Activity span of each coder: one line from the start time of their first
# judgment to the end time of their last, with coders ordered by start time.
# coord_flip() draws the x aesthetic (coders) on the vertical axis and the
# y aesthetic (time) on the horizontal axis; labs(x =, y =) follow the
# aesthetics, so "Coders" belongs to x and "Date and time" to y.
codings %>% 
  group_by(worker_id) %>% 
  summarise(start = min(start_time), end = max(end_time)) %>% 
  ggplot(aes(x = reorder(worker_id, desc(start)), ymin = start, ymax = end, group = worker_id)) + 
  geom_linerange() +
  coord_flip() +
  scale_x_discrete(breaks = NULL) +
  labs(
    title = "Coder activity times and duration"
    , subtitle = "Each horizontal line represents the activity time span and duration of a single coder"
    , x = "Coders"
    , y = "Date and time"
  )

# Box plots of per-judgment durations (seconds) for each coder. Coders are
# ordered by their mean duration, and durations are shown on a log2 scale.
codings %>% 
  group_by(worker_id) %>% 
  mutate(m_ = mean(duration)) %>% 
  ungroup() %>% 
  ggplot(aes(y = duration, x = reorder(worker_id, m_), group = worker_id)) + 
  geom_boxplot(
    fill = NA, alpha = .75
    , outlier.colour = "darkgrey", outlier.size = .25, outlier.alpha = .5
  ) +
  coord_flip() +
  scale_x_discrete(name = NULL, breaks = NULL) +
  scale_y_continuous(
    name = "Seconds (base-2 logarithmic scale)"
    , trans = "log2"
  ) + 
  labs(
    title = "Distributions of judgment times across coders"
  )

# Each coder's number of judgments against their median per-judgment
# duration, both on log10 scales, with a smoothed trend line and small jitter.
coder_stats %>% 
  ggplot(aes(n_judgments, median_duration)) + 
    geom_smooth(color = "black") +
    geom_jitter(alpha = .5, height = .01, width = .01) + 
    scale_y_continuous(trans = "log10") +
    scale_x_continuous(trans = "log10") +
    labs(
      title = "Relationship between number of judgments and median per-judgment duration."
      , subtitle = "1% horizontal and vertical jitter added to aid readability."
      , caption = "Note that values are depicted on log-10 scales."
      , x = "No. judgments"
      , y = "Median per-judgment duration"
    ) + 
    theme(plot.caption = element_text(hjust = 0))

# Scatter plot of each coder's number of judgments (log10 x axis) against the
# entropy of the labels they assigned; open circles with slight jitter.
ggplot(coder_stats, aes(x = n_judgments, y = judgment_entropy)) +
  geom_jitter(shape = 1, alpha = .9, width = .01, height = .01) +
  scale_x_continuous(trans = "log10") +
  labs(
    title = "Number of judgments against coder judgment entropy",
    subtitle = "Vertical and horizontal jitter of max. 1% added to avoid over-plotting",
    x = "Number of judgments (base-10 log scale)",
    y = "Coder judgment entropy"
  )

# compute and inspect tweet stats ----

# Per-tweet summary statistics: judgment durations, number of coders, label
# entropy, and the most frequent (plurality) label, its count, and whether
# the runner-up label ties it.
tweet_stats <- codings %>%
  group_by(item_id, status_id, source) %>%
  summarise(
    median_duration = median(duration),
    mean_duration = mean(duration),
    n_judgments = n_distinct(worker_id),
    judgment_entropy = entropy(coding),
    # label frequencies for this tweet, most frequent first
    judgments = list(rev(sort(table(coding)))),
    mode_label = names(judgments[[1]])[1],
    mode_label_n = as.integer(judgments[[1]][[1]]),
    # a tie needs a second label whose count equals the top count
    mode_is_tie = length(judgments[[1]]) > 1L &&
      judgments[[1]][[2]] == judgments[[1]][[1]]
  ) %>%
  ungroup()

# Summary statistics of judgment durations (seconds) per label class.
# na.rm = TRUE is applied uniformly: quantile() errors on NA input without
# it, and sd()/skewness() would otherwise return NA while mean/median don't.
codings %>% 
  group_by(coding) %>% 
  summarize(
    Mean = mean(duration, na.rm = TRUE)
    , `Std.Dev.` = sd(duration, na.rm = TRUE)
    , `10%-quantile` = quantile(duration, .10, na.rm = TRUE)
    , `25%-quantile` = quantile(duration, .25, na.rm = TRUE)
    , Median = median(duration, na.rm = TRUE)
    , `75%-quantile` = quantile(duration, .75, na.rm = TRUE)
    , `90%-quantile` = quantile(duration, .90, na.rm = TRUE)
    , Skewness = e1071::skewness(duration, na.rm = TRUE)
  ) %>% 
  ungroup()

# Distribution of tweet-level judgment entropy, faceted by each tweet's most
# frequent (plurality) label. Bars show proportions within each facet and are
# filled by how many judgments chose that plurality label. The `label_abbr`
# values are quoted strings, so label_parsed renders them without the quotes.
tweet_stats %>% 
  count(mode_label, mode_label_n, Entropy = sprintf("%0.3f", judgment_entropy)) %>% 
  group_by(mode_label) %>% 
  mutate(Proportion = n / sum(n)) %>% 
  ungroup() %>% 
  ggplot(aes(x = Entropy, y = Proportion, fill = as.factor(mode_label_n))) + 
  geom_col(alpha = .9) +
  facet_grid(
    cols = vars(factor(mode_label, names(label_abbr), label_abbr))
    , labeller = label_parsed
  ) + 
  scale_fill_brewer(type = "seq", palette = 2) + 
  labs(
    x = "\nTweet-level judgment entropy"
    , y = "Relative proportions"
    , fill = "Tweet-level count of most frequent label among judgments"
  ) + 
  theme(
    legend.position = "top"
    , axis.text.x = element_text(angle = 30, hjust = 1, vjust = 1)
  )

# filter "bad" coders ----

# worker IDs of coders flagged in `coder_stats` (non-missing `badbadnotgood`)
bad_coders <- coder_stats$worker_id[!is.na(coder_stats$badbadnotgood)]
(n_bad_workers <- length(bad_coders))

# number of judgments contributed by the flagged coders
n_bad_removed <- sum(codings$worker_id %in% bad_coders)

# visualize which coders are flagged: flagged coders as red plus signs,
# retained coders as black open circles
coder_stats %>%   
  ggplot(aes(n_judgments, judgment_entropy, color = !is.na(badbadnotgood), shape = !is.na(badbadnotgood))) +
    geom_jitter(alpha = .9, size = 1, width = .01, height = .01) +
    scale_x_continuous(trans = "log10") + 
    scale_color_manual(
      breaks = c(FALSE, TRUE)
      , values = c("black", "red")
      , labels = c("No", "Yes")
    ) +
    scale_shape_manual(breaks = c(FALSE, TRUE), values = c(1, 3)) +
    guides(shape = "none") + 
    labs(
      title = "Number of judgments against coder judgment entropy"
      , subtitle = "Vertical and horizontal jitter of max. 1% added to avoid over-plotting"
      , x = "Number of judgments (base-10 log scale)"
      , y = "Coder judgment entropy"
      , color = "Contributions removed"
    ) + 
    theme(legend.position = "bottom")


# drop missing judgments, judgments by flagged coders, and judgments that
# took less than 4 seconds
codings_no_bad_coders <- filter(
  codings,
  !is.na(coding),
  !(worker_id %in% bad_coders),
  duration >= 4
)

label_counts <- table(codings_no_bad_coders$coding, useNA = "always")
codings_no_bad_coders %>% count(coding)

# distribution of remaining judgments per tweet
codings_no_bad_coders %>%
  count(item_id, name = "n_judgments") %>%
  count(n_judgments)

# aggregate annotations with Dawid-Skene per annotator model ----

# Fit the Dawid-Skene model by EM on the filtered judgments, using the
# observed label proportions as the class-prevalence prior.
em_fit <- codings_no_bad_coders %>%
  em(
    item.col = item_id,
    annotator.col = worker_id,
    label.col = coding,
    .prevalence.prior = as.array(prop.table(table(codings_no_bad_coders$coding))),
    max.iters = 500,
    .min.relative.diff = 0.0001
  )

## inspect estimated parameters ----

# label class prevalence estimates vs. prevalence of model-induced labelings and plurality voting labels 
em_fit$est_class_prevl
# check "Value" section in ?AnnotationModelsR::em for explanations

# confusion matrix of model-induced labelings (rows) and plurality voting
# labels (columns); label pairs that never co-occur are shown as 0, not NA
em_fit$est_class_probs %>%
  count(labeling, majority_vote) %>% 
  pivot_wider(names_from = "majority_vote", values_from = "n", values_fill = 0L)

# annotator ability estimates: density of the per-coder estimated probability
# for rows where `coding == labeled`, one facet per class
# (NOTE(review): presumably the diagonal of each coder's confusion matrix,
# i.e. true-positive rates — see ?AnnotationModelsR::em).
# The `label_abbr` values are quoted strings, rendered via label_parsed.
em_fit$est_annotator_params %>% 
  filter(coding == labeled) %>% 
  ggplot(aes(est_prob)) + 
  facet_grid(
    cols = vars(factor(coding, names(label_abbr), label_abbr))
    , labeller = label_parsed
  ) +
  geom_density(alpha = .75, fill = "black", color = NA) +
  scale_x_continuous(breaks = seq(.2, 1, .2), limits = c(0, 1)) +
  labs(
    title = "Coder true-positive detection ability estimates"
    , subtitle = 'Refined model estimates (after dropping "spam" coders)'
    , x = NULL
    , y = "Density"
  )

## extract labelings ----

labelings <- em_fit$est_class_probs

# save labeling: tweet metadata (deduplicated), joined with tweet-level coding
# stats, then right-joined onto the model-induced labelings so that every
# labeled item is kept
labeled_tweets <- codings %>% 
  select(
    item_id
    , country_iso3c, party_id, party_name_short
    , user_id, status_id
    , text_en = `source`
    , prob_political, cluster_id
  ) %>% 
  distinct() %>% 
  left_join(select(tweet_stats, item_id, median_duration, judgment_entropy), by = "item_id") %>% 
  right_join(labelings, by = "item_id")

fp <- file.path(data_path, "intermediate", "labelings", "dawidskene_labelings_sample_1.csv")
# never overwrite an existing labelings file, but say so instead of silently
# skipping; create the output directory if it does not exist yet
if (file.exists(fp)) {
  message("Labelings file already exists, not overwriting: ", fp)
} else {
  dir.create(dirname(fp), recursive = TRUE, showWarnings = FALSE)
  write_csv(labeled_tweets, fp)
}
